Compare binding to human versus mouse Mxra8¶

In [1]:
import itertools

import altair as alt

import pandas as pd

# Lift altair's row limit so charts can be built from data frames with many rows.
_ = alt.data_transformers.disable_max_rows()
In [2]:
# this cell is tagged parameters for `papermill` parameterization

# Input CSV paths.
# `entry_csv` is the name that the injected parameters and the data-reading
# cell actually use, so it is declared here alongside the other inputs.
entry_csv = None
# Not referenced elsewhere in this notebook; kept so existing callers that
# still pass it are not broken.
entry_293T_human_Mxra8 = None
binding_human_Mxra8 = None
binding_mouse_Mxra8 = None
addtl_site_annotations = None
site_numbering_map = None

# Output HTML paths for the saved charts.
mut_corr_chart_html = None
site_corr_chart_html = None
site_chart_html = None
In [3]:
# Parameters
# Concrete input/output paths that replace the `None` defaults declared in the
# cell tagged `parameters` (presumably the cell injected by papermill at run time).
entry_csv = "results/func_effects/averages/293T-Mxra8_entry_func_effects.csv"
binding_human_Mxra8 = "results/receptor_affinity/averages/human_Mxra8_mut_effect.csv"
binding_mouse_Mxra8 = "results/receptor_affinity/averages/mouse_Mxra8_mut_effect.csv"
addtl_site_annotations = "data/addtl_site_annotations.csv"
site_numbering_map = "data/site_numbering_map.csv"
mut_corr_chart_html = "results/compare_human_mouse_mxra8_mut_binding_corr.html"
site_corr_chart_html = "results/compare_human_mouse_mxra8_site_binding_corr.html"
site_chart_html = "results/compare_human_mouse_mxra8_site_chart.html"
In [4]:
# Additional hardcoded parameters

# Filters on the cell-entry data.
# `min_entry` is the initial value of the entry slider and the cutoff used
# before computing site-level summaries.
min_entry = -4
# NOTE: despite its name this is applied as an UPPER bound on `effect_std`
# (rows are kept when `effect_std <= min_entry_std`).
min_entry_std = 2.25
entry_name = "entry in 293T-Mxra8 cells"
# Minimum `times_seen` required for both the entry and the binding data.
min_times_seen = 2

# Per-ligand settings, all keyed by the ligand identifier:
#   ligands               -> display name used in chart titles and tooltips
#   binding_csvs          -> CSV with the binding effects for that ligand
#   binding_csv_col_names -> prefix of the `<prefix> binding_median` / `<prefix> binding_std` columns
#   max_binding_stds      -> largest `<prefix> binding_std` that is retained
ligands = {"mouse_Mxra8": "mouse Mxra8", "human_Mxra8": "human Mxra8"}
binding_csvs = {
    "human_Mxra8": binding_human_Mxra8,
    "mouse_Mxra8": binding_mouse_Mxra8,
}
binding_csv_col_names = {"human_Mxra8": "Mxra8", "mouse_Mxra8": "Mxra8"}
max_binding_stds = {"human_Mxra8": 2.5, "mouse_Mxra8": 2.25}

# Columns taken from the additional site annotations, mapped to the names
# they get in the merged data frame.
addtl_site_annotations_cols = {
    "domain": "domain",
    "contacts": "Mxra8 contact",
}

assert len(ligands) == 2, "saving for corr charts only works for 2 ligands currently"

# Every per-ligand dict is indexed with the keys of `ligands`, so a mismatch
# would otherwise only surface as a KeyError part-way through reading the data.
assert (
    set(ligands) == set(binding_csvs) == set(binding_csv_col_names) == set(max_binding_stds)
), "per-ligand settings must all be keyed by the same ligands"

Read the data¶

In [5]:
# read the data

# Cell-entry effects: keep mutations seen at least `min_times_seen` times whose
# `effect_std` does not exceed `min_entry_std`, and expose the effect as `entry`.
# No `mutation` label is built here: it would be discarded by the column
# selection below, and it is added once all merges are done.
print(f"Reading cell entry from {entry_csv=}")
data_df = (
    pd.read_csv(entry_csv)
    .query("times_seen >= @min_times_seen")
    .query("effect_std <= @min_entry_std")
    [["site", "wildtype", "mutant", "effect"]]
    .rename(columns={"effect": "entry"})
)

# Add one binding column and one tooltip-label column per ligand.
for ligand in ligands:
    print(f"Reading binding to {ligand=} from {binding_csvs[ligand]=}")
    max_std = max_binding_stds[ligand]
    col_name = binding_csv_col_names[ligand]
    # Keep mutations seen often enough, with `frac_models == 1`, and with a
    # binding standard deviation no larger than this ligand's cutoff; the
    # median binding column is renamed to the ligand identifier.
    bind_df = (
        pd.read_csv(binding_csvs[ligand])
        .query("times_seen >= @min_times_seen")
        .query("frac_models == 1")
        .query(f"`{col_name} binding_std` <= @max_std")
        .rename(columns={f"{col_name} binding_median": ligand})
    )
    # NOTE(review): every column from position 11 onward is treated as a
    # per-replicate value — confirm against the CSV layout.
    bind_rep_cols = bind_df.columns[11:].tolist()
    # Label is "<median> (<rep1>, <rep2>, ...)"; the median and the replicate
    # values are all written with two decimals so the numbers line up
    # (e.g. "0.10" rather than "0.1").
    bind_df = bind_df.assign(
        **{
            f"{ligand}_label": lambda x: x.apply(
                lambda r: f"{r[ligand]:.2f} ({', '.join(format(r[c], '.2f') for c in bind_rep_cols)})",
                axis=1,
            )
        }
    )[["site", "wildtype", "mutant", ligand, f"{ligand}_label"]]
    # Left merge: mutations without a retained binding value keep their entry
    # row and get NaN for this ligand.
    data_df = data_df.merge(
        bind_df, how="left", on=["site", "mutant", "wildtype"], validate="1:1"
    )

# Attach `sequential_site` and `region` to every row. No `how` is given, so
# pandas' default inner join applies to this merge.
print(f"Adding sequential site from {site_numbering_map=}")
data_df = pd.merge(
    data_df,
    pd.read_csv(site_numbering_map)
    .rename(columns={"reference_site": "site"})
    .loc[:, ["site", "sequential_site", "region"]],
    on="site",
    validate="m:1",
)

# Attach the extra per-site annotations under their display names; sites
# without an annotation row are kept (left join).
print(f"Adding site annotations from {addtl_site_annotations=}")
data_df = pd.merge(
    data_df,
    pd.read_csv(addtl_site_annotations)
    .loc[:, ["sequential_site", *addtl_site_annotations_cols]]
    .rename(columns=addtl_site_annotations_cols),
    how="left",
    on="sequential_site",
    validate="m:1",
)

# Drop rows where the mutant equals the wildtype, build the mutation label
# (wildtype + site + mutant), mark unannotated sites as "no" contact, and
# order rows by sequential site and then mutant.
data_df = (
    data_df[data_df["wildtype"] != data_df["mutant"]]
    .assign(
        **{
            "mutation": lambda df: df["wildtype"] + df["site"].astype(str) + df["mutant"],
            "Mxra8 contact": lambda df: df["Mxra8 contact"].fillna("no"),
        }
    )
    .sort_values(by=["sequential_site", "mutant"])
    .reset_index(drop=True)
)

data_df
Reading cell entry from entry_csv='results/func_effects/averages/293T-Mxra8_entry_func_effects.csv'
Reading binding to ligand='mouse_Mxra8' from binding_csvs[ligand]='results/receptor_affinity/averages/mouse_Mxra8_mut_effect.csv'
Reading binding to ligand='human_Mxra8' from binding_csvs[ligand]='results/receptor_affinity/averages/human_Mxra8_mut_effect.csv'
Adding sequential site from site_numbering_map='data/site_numbering_map.csv'
Adding site annotations from addtl_site_annotations='data/addtl_site_annotations.csv'
Out[5]:
site wildtype mutant entry mouse_Mxra8 mouse_Mxra8_label human_Mxra8 human_Mxra8_label sequential_site region domain Mxra8 contact mutation
0 -1(E3) M I -7.5410 NaN NaN NaN NaN 1 E3 NaN no M-1(E3)I
1 -1(E3) M T -7.5630 NaN NaN NaN NaN 1 E3 NaN no M-1(E3)T
2 1(E3) S A -1.0250 -0.11910 -0.12 (-0.06, -0.18) 0.04762 0.05 (0.06, 0.03) 2 E3 E3 no S1(E3)A
3 1(E3) S C -0.7132 -0.21170 -0.21 (-0.44, 0.01) -0.73310 -0.73 (-0.61, -0.85) 2 E3 E3 no S1(E3)C
4 1(E3) S D 0.1852 0.02613 0.03 (0.02, 0.04) -0.21540 -0.22 (-0.21, -0.22) 2 E3 E3 no S1(E3)D
... ... ... ... ... ... ... ... ... ... ... ... ... ...
18957 439(E1) H V -0.4753 NaN NaN NaN NaN 988 E1 E1-cytoplasmic no H439(E1)V
18958 439(E1) H W -0.2051 0.23070 0.23 (-0.03, 0.49) -0.28620 -0.29 (-0.64, 0.07) 988 E1 E1-cytoplasmic no H439(E1)W
18959 439(E1) H Y -0.2293 -0.01344 -0.01 (-0.12, 0.1) -0.24560 -0.25 (-0.29, -0.2) 988 E1 E1-cytoplasmic no H439(E1)Y
18960 440(E1) * Q -3.3990 0.13000 0.13 (-0.02, 0.28) -1.51300 -1.51 (-2.55, -0.48) 989 E1 NaN no *440(E1)Q
18961 440(E1) * Y -1.0960 0.64660 0.65 (1.12, 0.17) 0.59920 0.60 (1.22, -0.02) 989 E1 NaN no *440(E1)Y

18962 rows × 13 columns

Simple correlation of binding to different ligands across all mutations¶

In [6]:
# plot the data

# Mouseover selections keyed on `site` and on `mutation`; the site selection
# is reused by the site-level correlation chart further down.
site_selection = alt.selection_point(fields=["site"], on="mouseover", empty=False)

mut_selection = alt.selection_point(fields=["mutation"], on="mouseover", empty=False)

# Slider for the minimum entry value, ranging from the smallest observed
# entry up to 0 and starting at `min_entry`.
min_entry_slider = alt.param(
    value=min_entry,
    name="min_entry_slider",
    bind=alt.binding_range(
        name=f"minimum {entry_name}",
        min=data_df["entry"].min(),
        max=0,
    ),
)

# Only the columns the chart needs are passed to altair.
mut_corr_base = alt.Chart(
    data_df[
        [
            "mutation",
            "entry",
            "site",
            *ligands,
            *(f"{lig}_label" for lig in ligands),
        ]
    ]
)

# One scatter per ligand pair: x is binding to the first ligand, y binding to
# the second. Points for the hovered site are drawn red, larger and more
# opaque; the hovered mutation gets a thicker outline.
for ligand1, ligand2 in itertools.combinations(ligands, 2):

    mut_corr_chart = (
        mut_corr_base
        .mark_circle(stroke="black")
        .add_params(site_selection, mut_selection, min_entry_slider)
        .transform_filter(alt.datum["entry"] >= min_entry_slider)
        .encode(
            x=alt.X(
                ligand1,
                scale=alt.Scale(nice=False, padding=5),
                title=f"binding to {ligands[ligand1]}",
            ),
            y=alt.Y(
                ligand2,
                scale=alt.Scale(nice=False, padding=5),
                title=f"binding to {ligands[ligand2]}",
            ),
            color=alt.condition(site_selection, alt.value("red"), alt.value("gray")),
            opacity=alt.condition(site_selection, alt.value(0.9), alt.value(0.15)),
            size=alt.condition(site_selection, alt.value(55), alt.value(40)),
            strokeWidth=alt.condition(mut_selection, alt.value(3), alt.value(0.6)),
            tooltip=[
                "mutation",
                alt.Tooltip("entry", title=entry_name, format=".2f"),
                alt.Tooltip(f"{ligand1}_label", title=ligands[ligand1]),
                alt.Tooltip(f"{ligand2}_label", title=ligands[ligand2]),
            ],
        )
        .properties(height=175, width=175)
        .configure_axis(grid=False)
    )

    display(mut_corr_chart)

    # The same output path is written on every iteration; the assertion that
    # there are exactly two ligands means there is only one pair.
    print(f"Saving to {mut_corr_chart_html}")
    mut_corr_chart.save(mut_corr_chart_html)
Saving to results/compare_human_mouse_mxra8_mut_binding_corr.html

Plot site effects on binding¶

We pre-filter on the entry cutoff, and then get the summed positive, negative, and absolute effects at each site for each ligand:

In [7]:
# Keep only mutations whose entry effect is at least `min_entry`.
data_filtered_df = data_df[data_df["entry"] >= min_entry]

# Long format (one row per mutation and ligand), then per ligand and site:
#   positive_effect = sum of the effects clipped below at 0
#   negative_effect = sum of the effects clipped above at 0
#   absolute_effect = sum of the absolute effects
# Missing effects are skipped by `sum`, so a site with no measured mutation
# for a ligand gets 0 for all three values rather than NaN.
site_df = (
    pd.melt(
        data_filtered_df,
        id_vars=["site", "sequential_site", "wildtype", "region", "Mxra8 contact"],
        value_vars=list(ligands),
        var_name="ligand",
        value_name="effect",
    )
    .groupby(
        ["ligand", "site", "sequential_site", "wildtype", "region", "Mxra8 contact"],
        as_index=False,
        dropna=False,
    )
    .agg(
        positive_effect=("effect", lambda vals: vals.clip(lower=0).sum()),
        negative_effect=("effect", lambda vals: vals.clip(upper=0).sum()),
        absolute_effect=("effect", lambda vals: vals.abs().sum()),
    )
)
In [8]:
chart_width = 950

# Bar chart of site-level binding effects, one row of panels per ligand.
# Each bar spans from the summed negative effect to the summed positive
# effect at a site and is colored by the site's Mxra8 contact category.
site_binding_chart = (
    alt.Chart(
        site_df.assign(ligand_name=lambda x: "binding to " + x["ligand"].map(ligands))
    )
    .encode(
        alt.X(
            "site",
            sort=alt.SortField("sequential_site"),
            axis=alt.Axis(
                # Sparse tick labels: every 130th entry, starting at entry 50,
                # of the site list ordered by sequential site. `site_df` has
                # one row per ligand and site, so each site appears once per
                # ligand in that list. The selection is converted to a plain
                # list so altair receives an ordinary array of labels rather
                # than a pandas Series.
                values=(
                    site_df[["sequential_site", "site"]]
                    .sort_values("sequential_site")["site"]
                    .iloc[50::130]
                    .tolist()
                ),
                labelAngle=0,
            ),
        ),
        alt.Y("positive_effect", title=None, scale=alt.Scale(nice=False, padding=4)),
        alt.Y2("negative_effect", title=None),
        alt.Color(
            "Mxra8 contact",
            scale=alt.Scale(
                domain=["no", "wrapped", "intraspike", "interspike"],
                range=["gray", "red", "purple", "orange"],
            ),
        ),
        alt.Row(
            "ligand_name",
            title=None,
            header=alt.Header(labelFontStyle="bold", labelPadding=2),
            spacing=5,
        ),
        tooltip=[
            "site",
            "wildtype",
            alt.Tooltip("positive_effect", format=".2f"),
            alt.Tooltip("negative_effect", format=".2f"),
            "Mxra8 contact",
        ],
    )
    .mark_bar(opacity=1, width=2)
    .properties(width=chart_width, height=0.23 * chart_width)
    # Each ligand's panel gets its own y scale.
    .resolve_scale(y="independent")
)

Make overlay bar with regions:

In [9]:
# Colored band with one rectangle per sequential site, colored by region.
region_chart = (
    alt.Chart(site_df[["sequential_site", "region"]].drop_duplicates())
    .mark_rect(strokeWidth=0, opacity=0.75)
    .encode(
        x=alt.X("sequential_site:O", axis=None),
        color=alt.Color(
            "region",
            scale=alt.Scale(range=["AliceBlue", "CadetBlue", "CadetBlue", "AliceBlue"]),
            legend=None,
        ),
    )
    .properties(width=chart_width)
)

# Each region name is placed at the mean sequential site of that region.
text_df = site_df.groupby("region", as_index=False).agg(x=("sequential_site", "mean"))

text_chart = (
    alt.Chart(text_df)
    .mark_text(fontSize=18, fontWeight="bold")
    .encode(
        x=alt.X(
            "x:Q",
            scale=alt.Scale(domain=(site_df["sequential_site"].min(), site_df["sequential_site"].max())),
            title=None,
            axis=None,
        ),
        text=alt.Text("region"),
    )
    .properties(height=21, width=chart_width)
)

# Region labels are layered on top of the colored band.
overlay_chart = alt.layer(region_chart, text_chart)

Combine overlay and site chart:

In [10]:
# Stack the region overlay above the site binding chart. The two charts keep
# separate color scales, and the combined chart can be zoomed/panned along x only.
site_chart = (
    alt.vconcat(overlay_chart, site_binding_chart, spacing=1)
    .resolve_scale(color="independent")
    .configure_view(stroke="black", strokeWidth=1, strokeOpacity=1)
    .configure_header(labelFontSize=18)
    .configure_legend(titleFontSize=18, labelFontSize=18)
    .configure_axis(grid=False, labelFontSize=14, titleFontSize=18)
    .interactive(bind_x=True, bind_y=False)
)

print(f"Saving to {site_chart_html}")
site_chart.save(site_chart_html)

site_chart
Saving to results/compare_human_mouse_mxra8_site_chart.html
Out[10]:

Plot correlations in site effects¶

In [11]:
# Reshape the site summaries so each row is one (site, metric) pair and each
# ligand has its own column holding that metric's value.
site_corr_df = pd.pivot_table(
    pd.melt(
        site_df,
        id_vars=["ligand", "site", "wildtype", "region", "Mxra8 contact"],
        value_vars=["positive_effect", "negative_effect", "absolute_effect"],
        var_name="metric",
        value_name="effect",
    ),
    index=["site", "wildtype", "region", "Mxra8 contact", "metric"],
    columns="ligand",
    values="effect",
).reset_index()
In [12]:
tooltip_cols = ["site", "wildtype", "region", "Mxra8 contact"]

# One faceted scatter per ligand pair, with one panel per site-level metric.
for ligand1, ligand2 in itertools.combinations(ligands, 2):

    # Correlation between the two ligands' site values, computed separately
    # for each metric. Computed directly per group so it does not depend on
    # the name of the columns axis of `site_corr_df`.
    corrs = {
        metric: metric_df[ligand1].corr(metric_df[ligand2])
        for metric, metric_df in site_corr_df.groupby("metric")
    }

    site_corr_chart = (
        alt.Chart(
            site_corr_df[tooltip_cols + [ligand1, ligand2, "metric"]]
            .assign(
                # Panel titles carry the metric name and its correlation.
                metric=lambda x: x["metric"].map(
                    {
                        metric:
                            f"{metric.replace('_', ' ')} at site (r = {corrs[metric]:.2f})"
                        for metric in site_corr_df["metric"].unique()
                    }
                )
            )
        )
        # Reuses the mouseover site selection defined for the mutation-level chart.
        .add_params(site_selection)
        .encode(
            alt.X(ligand1, title=ligands[ligand1], scale=alt.Scale(nice=False, padding=6)),
            alt.Y(ligand2, title=ligands[ligand2], scale=alt.Scale(nice=False, padding=6)),
            alt.Column(
                "metric",
                title=None,
                header=alt.Header(labelFontStyle="bold", labelFontSize=11, labelPadding=2),
            ),
            color=alt.condition(site_selection, alt.value("red"), alt.value("gray")),
            strokeWidth=alt.condition(site_selection, alt.value(3), alt.value(1)),
            size=alt.condition(site_selection, alt.value(60), alt.value(35)),
            opacity=alt.condition(site_selection, alt.value(1), alt.value(0.25)),
            tooltip=[
                *tooltip_cols,
                alt.Tooltip(ligand1, title=ligands[ligand1], format=".2f"),
                alt.Tooltip(ligand2, title=ligands[ligand2], format=".2f"),
            ],
        )
        .mark_circle(stroke="black")
        # Each metric panel gets its own x and y scales.
        .resolve_scale(x="independent", y="independent")
        .configure_axis(grid=False)
        .properties(width=140, height=140)
    )

    display(site_corr_chart)

    # The same output path is written on every iteration; the notebook asserts
    # that there are exactly two ligands, so there is only one pair.
    print(f"Saving to {site_corr_chart_html}")
    site_corr_chart.save(site_corr_chart_html)
Saving to results/compare_human_mouse_mxra8_site_binding_corr.html
In [ ]: